
Amortized Bayesian Inference with PyTorch

pytorch
variational auto-encoders
amortized bayesian inference
Published

July 25, 2025

Transparency and Heuristics in Latent Space
The cost of generating new sample data can be prohibitive. There is a secondary, distinct cost attached to the 'construction' of novel data. Principal Components Analysis can be seen as a technique for optimally reconstructing a complex multivariate data set from a compressed, lower-dimensional space. Variational auto-encoders achieve still more flexible reconstructions in non-linear cases. Drawing a new sample from the posterior predictive distribution of a Bayesian model similarly gives us insight into the variability of realised data. Both methods assume a latent model of the data-generating process that leverages a compressed representation of the data. They are different heuristics with different consequences for how we understand variability in the world. Both encode differing degrees of inductive bias through architectural specification, but of the two, Bayesian methods do so transparently.

Reconstruction Error

It’s natural to seek shortcuts.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as dsets
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import pymc as pm

Job Satisfaction Data

# Standard deviations
stds = np.array([0.939, 1.017, 0.937, 0.562, 0.760, 0.524, 
                 0.585, 0.609, 0.731, 0.711, 1.124, 1.001])

n = len(stds)

# Lower triangular correlation values as a flat list
corr_values = [
    1.000,
    .668, 1.000,
    .635, .599, 1.000,
    .263, .261, .164, 1.000,
    .290, .315, .247, .486, 1.000,
    .207, .245, .231, .251, .449, 1.000,
   -.206, -.182, -.195, -.309, -.266, -.142, 1.000,
   -.280, -.241, -.238, -.344, -.305, -.230,  .753, 1.000,
   -.258, -.244, -.185, -.255, -.255, -.215,  .554,  .587, 1.000,
    .080,  .096,  .094, -.017,  .151,  .141, -.074, -.111,  .016, 1.000,
    .061,  .028, -.035, -.058, -.051, -.003, -.040, -.040, -.018,  .284, 1.000,
    .113,  .174,  .059,  .063,  .138,  .044, -.119, -.073, -.084,  .563,  .379, 1.000
]

# Fill correlation matrix
corr_matrix = np.zeros((n, n))
idx = 0
for i in range(n):
    for j in range(i+1):
        corr_matrix[i, j] = corr_values[idx]
        corr_matrix[j, i] = corr_values[idx]
        idx += 1

# Covariance matrix: Sigma = D * R * D
cov_matrix = np.outer(stds, stds) * corr_matrix
#cov_matrix_test = np.dot(np.dot(np.diag(stds), corr_matrix), np.diag(stds))
columns=["JW1","JW2","JW3", "UF1","UF2","FOR", "DA1","DA2","DA3", "EBA","ST","MI"]
corr_df = pd.DataFrame(corr_matrix, columns=columns)

cov_df = pd.DataFrame(cov_matrix, columns=columns)
cov_df
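Since the covariance is built as Σ = D R D, it's worth sanity-checking that the outer-product shortcut agrees with the explicit diagonal-matrix product, and that the result is a valid (symmetric, positive definite) covariance matrix. A quick sketch on an illustrative 3×3 subset of the values above:

```python
import numpy as np

# Illustrative 3x3 subset of the standard deviations and correlations above
stds = np.array([0.939, 1.017, 0.937])
R = np.array([[1.000, 0.668, 0.635],
              [0.668, 1.000, 0.599],
              [0.635, 0.599, 1.000]])

# Outer-product form used above: Sigma_ij = s_i * s_j * R_ij
Sigma = np.outer(stds, stds) * R

# Equivalent explicit form: Sigma = D R D with D = diag(stds)
D = np.diag(stds)
Sigma_alt = D @ R @ D

assert np.allclose(Sigma, Sigma_alt)
# A valid covariance matrix is symmetric with strictly positive eigenvalues
assert np.allclose(Sigma, Sigma.T)
assert np.all(np.linalg.eigvalsh(Sigma) > 0)
```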

def make_sample(cov_matrix, size, columns):
    sample_df = pd.DataFrame(np.random.multivariate_normal([0]*12, cov_matrix, size=size), columns=columns)
    return sample_df

sample_df = make_sample(cov_matrix, 263, columns)
sample_df.head()
(Output: the first five rows of the 263×12 simulated sample, one column per indicator JW1–MI.)
data = sample_df.corr()

def plot_heatmap(data, title="Correlation Matrix", vmin=-.2, vmax=.2, ax=None, figsize=(10, 6), colorbar=True):
    data_matrix = data.values
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(data_matrix, cmap='viridis', vmin=vmin, vmax=vmax)

    # Annotate each cell with its value, choosing a contrasting text colour
    for i in range(data_matrix.shape[0]):
        for j in range(data_matrix.shape[1]):
            ax.text(
                j, i,                              # x, y coordinates
                f"{data_matrix[i, j]:.2f}",        # value to display
                ha="center", va="center",
                color="white" if data_matrix[i, j] < 0.5 else "black"
            )

    ax.set_title(title)
    # Set tick positions *before* the labels to avoid matplotlib warnings
    ax.set_xticks(np.arange(data.shape[1]))
    ax.set_xticklabels(data.columns)
    ax.set_yticks(np.arange(data.shape[0]))
    ax.set_yticklabels(data.index)
    if colorbar:
        plt.colorbar(im, ax=ax)

plot_heatmap(data, vmin=-1, vmax=1)

X = make_sample(cov_matrix, 100, columns=columns)
U, S, VT = np.linalg.svd(X, full_matrices=False)
ranks = [2, 5, 12]
reconstructions = []
for k in ranks:
    X_k = U[:, :k] @ np.diag(S[:k]) @ VT[:k, :]
    reconstructions.append(X_k)
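Before plotting, it helps to know what truncation costs numerically. By the Eckart–Young theorem, the squared Frobenius error of the rank-k SVD reconstruction equals the sum of the discarded squared singular values; a small self-contained check on random data:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 12))

U, S, VT = np.linalg.svd(X, full_matrices=False)
for k in [2, 5, 12]:
    X_k = U[:, :k] @ np.diag(S[:k]) @ VT[:k, :]
    # Eckart-Young: the squared Frobenius error of the rank-k truncation
    # equals the sum of the discarded squared singular values
    sq_err = np.linalg.norm(X - X_k, 'fro') ** 2
    assert np.isclose(sq_err, np.sum(S[k:] ** 2))
```

At full rank (k = 12) nothing is discarded and the reconstruction is exact up to floating-point error.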

# Plot original and reconstructed matrices
fig, axes = plt.subplots(1, len(ranks) + 1, figsize=(10,15))
axes[0].imshow(X, cmap='viridis')
axes[0].set_title("Original")
axes[0].axis("off")

for ax, k, X_k in zip(axes[1:], ranks, reconstructions):
    ax.imshow(X_k, cmap='viridis')
    ax.set_title(f"Rank {k}")
    ax.axis("off")

plt.suptitle("Reconstruction of Data Using SVD \n various truncation options",fontsize=12, x=.5, y=1.01)
plt.tight_layout()
plt.show()

Variational Auto-Encoders

class NumericVAE(nn.Module):
    def __init__(self, n_features, hidden_dim=64, latent_dim=8):
        super().__init__()
        
        # ---------- ENCODER ----------
        # First layer: compress input features into a hidden representation
        self.fc1 = nn.Linear(n_features, hidden_dim)
        
        # Latent space parameters (q(z|x)): mean and log-variance
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)       # μ(x)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)   # log(σ^2(x))
        
        # ---------- DECODER ----------
        # First layer: map latent variable z back into hidden representation
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        
        # Output distribution parameters for reconstruction p(x|z)
        # For numeric data, we predict both mean and log-variance per feature
        self.fc_out_mu = nn.Linear(hidden_dim, n_features)        # μ_x(z)
        self.fc_out_logvar = nn.Linear(hidden_dim, n_features)    # log(σ^2_x(z))

    # ENCODER forward pass: input x -> latent mean, log-variance
    def encode(self, x):
        h = F.relu(self.fc1(x))       # Hidden layer with ReLU
        mu = self.fc_mu(h)            # Latent mean vector
        logvar = self.fc_logvar(h)    # Latent log-variance vector
        return mu, logvar

    # Reparameterization trick: sample z = μ + σ * ε  (ε ~ N(0,1))
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)   # σ = exp(0.5 * logvar)
        eps = torch.randn_like(std)     # ε ~ N(0, I)
        return mu + eps * std           # z = μ + σ * ε

    # DECODER forward pass: latent z -> reconstructed mean, log-variance
    def decode(self, z):
        h = F.relu(self.fc2(z))             # Hidden layer with ReLU
        recon_mu = self.fc_out_mu(h)        # Mean of reconstructed features
        recon_logvar = self.fc_out_logvar(h)# Log-variance of reconstructed features
        return recon_mu, recon_logvar

    # Full forward pass: input x -> reconstructed (mean, logvar), latent params
    def forward(self, x):
        mu, logvar = self.encode(x)            # q(z|x)
        z = self.reparameterize(mu, logvar)    # Sample z from q(z|x)
        recon_mu, recon_logvar = self.decode(z)# p(x|z)
        return (recon_mu, recon_logvar), mu, logvar

    # Sample new synthetic data: z ~ N(0,I), decode to x
    def generate(self, n_samples=100):
        self.eval()
        with torch.no_grad():
            # Sample z from standard normal prior
            z = torch.randn(n_samples, self.fc_mu.out_features)
            
            # Decode to get reconstruction distribution parameters
            cont_mu, cont_logvar = self.decode(z)
            
            # Sample from reconstructed Gaussian: μ_x + σ_x * ε
            return cont_mu + torch.exp(0.5 * cont_logvar) * torch.randn_like(cont_mu)
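The reparameterization trick is the piece doing the real work here: it moves the randomness into ε so gradients can flow through μ and σ. It can be checked in isolation (a standalone sketch, independent of the class above):

```python
import torch

torch.manual_seed(0)
mu = torch.tensor([1.0, -2.0], requires_grad=True)
logvar = torch.tensor([0.0, 1.0], requires_grad=True)

def reparameterize(mu, logvar, n):
    std = torch.exp(0.5 * logvar)          # sigma = exp(0.5 * logvar)
    eps = torch.randn(n, mu.shape[0])      # eps ~ N(0, I)
    return mu + eps * std                  # z = mu + sigma * eps

z = reparameterize(mu, logvar, 100_000)

# Empirical moments match (mu, sigma) up to sampling error
assert torch.allclose(z.mean(0), mu, atol=0.02)
assert torch.allclose(z.std(0), torch.exp(0.5 * logvar), atol=0.02)

# Gradients flow through the sampling step back to mu and logvar
z.sum().backward()
assert mu.grad is not None and logvar.grad is not None
```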
def vae_loss(recon_mu, recon_logvar, x, mu, logvar):
    # Reconstruction loss: Gaussian log likelihood
    recon_var = torch.exp(recon_logvar)
    recon_nll = 0.5 * (torch.log(2 * torch.pi * recon_var) + (x - recon_mu) ** 2 / recon_var)
    recon_loss = recon_nll.sum(dim=1).mean()  # sum over features, mean over batch

    # KL divergence: D_KL(q(z|x) || p(z)) where p(z)=N(0,I)
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    kl_loss = kl_div.mean()

    return recon_loss + kl_loss, recon_loss, kl_loss
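The KL term in `vae_loss` has a closed form we can test directly: it vanishes exactly when q(z|x) matches the standard normal prior, and for a single dimension it equals ½(σ² + μ² − 1 − log σ²). A standalone check:

```python
import torch

def kl_std_normal(mu, logvar):
    # D_KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

# When q(z|x) = N(0, I) the KL term vanishes
assert torch.allclose(kl_std_normal(torch.zeros(4, 8), torch.zeros(4, 8)),
                      torch.zeros(4))

# Single-dimension closed form: 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2)
mu = torch.tensor([[0.5]])
logvar = torch.log(torch.tensor([[2.0]]))  # sigma^2 = 2
expected = 0.5 * (2.0 + 0.25 - 1.0 - torch.log(torch.tensor(2.0)))
assert torch.allclose(kl_std_normal(mu, logvar), expected.reshape(1))
```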
def prep_data_vae(sample_size=1000):
    sample_df = make_sample(cov_matrix=cov_matrix, size=sample_size, columns=columns)

    X_train, X_test = train_test_split(sample_df.values, test_size=0.2, random_state=890)

    X_train = torch.tensor(X_train, dtype=torch.float32)
    X_test = torch.tensor(X_test, dtype=torch.float32)

    train_loader = torch.utils.data.DataLoader(X_train, batch_size=32, shuffle=True)
    test_loader = torch.utils.data.DataLoader(X_test, batch_size=32)
    return train_loader, test_loader



def train_vae(vae, optimizer, train, test, patience=10, wait=0, n_epochs=1000):
    best_loss = float('inf')
    losses = []

    for epoch in range(n_epochs):
        vae.train()
        train_loss = 0.0
        
        for batch in train:
            optimizer.zero_grad()

            (recon_mu, recon_logvar), mu, logvar = vae(batch)
            loss, recon_loss, kl_loss = vae_loss(recon_mu, recon_logvar, batch, mu, logvar)

            loss.backward()
            optimizer.step()

            train_loss += loss.item() * batch.size(0)

        avg_train_loss = train_loss / train.dataset.shape[0]

        # --- Test Loop ---
        vae.eval()
        test_loss = 0.0
        with torch.no_grad():
            for batch in test:
                (recon_mu, recon_logvar), mu, logvar = vae(batch)
                loss, _, _ = vae_loss(recon_mu, recon_logvar, batch, mu, logvar)
                test_loss += loss.item() * batch.size(0)
        avg_test_loss = test_loss / test.dataset.shape[0]

        print(f"Epoch {epoch+1}/{n_epochs} | Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}")

        if avg_test_loss < best_loss - 1e-4:
            best_loss, wait = avg_test_loss, 0
            best_state = vae.state_dict()
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                vae.load_state_dict(best_state)  # restore best
                break
        losses.append([avg_train_loss, avg_test_loss, best_loss])

    return vae, pd.DataFrame(losses, columns=['train_loss', 'test_loss', 'best_loss'])

train_500, test_500 = prep_data_vae(500)
vae = NumericVAE(n_features=train_500.dataset.shape[1], hidden_dim=64)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
vae_fit_500, losses_df_500 = train_vae(vae, optimizer, train_500, test_500)


train_1000, test_1000 = prep_data_vae(1000)
vae = NumericVAE(n_features=train_1000.dataset.shape[1], hidden_dim=64)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
vae_fit_1000, losses_df_1000 = train_vae(vae, optimizer, train_1000, test_1000)

train_10_000, test_10_000 = prep_data_vae(10_000)
vae = NumericVAE(n_features=train_10_000.dataset.shape[1], hidden_dim=64)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
vae_fit_10_000, losses_df_10_000 = train_vae(vae, optimizer, train_10_000, test_10_000)
fig, axs = plt.subplots(1, 3, figsize=(8, 6))
axs=axs.flatten()
losses_df_500[['train_loss', 'test_loss']].plot(ax=axs[0])
losses_df_1000[['train_loss', 'test_loss']].plot(ax=axs[1])
losses_df_10_000[['train_loss', 'test_loss']].plot(ax=axs[2])

axs[0].set_title("Training and Test Losses \n 500 observations");
axs[1].set_title("Training and Test Losses \n 1000 observations");
axs[2].set_title("Training and Test Losses \n 10,000 observations");

def bootstrap_residuals(vae_fit, X_test, sample_df, n_boot=1000):
    resid_array = np.zeros((n_boot, len(sample_df.columns), len(sample_df.columns)))
    for i in range(n_boot):
        # generate() returns a tensor; convert for the pandas correlation step
        recon_data = vae_fit.generate(n_samples=len(X_test)).numpy()
        reconstructed_df = pd.DataFrame(recon_data, columns=sample_df.columns)
        resid = pd.DataFrame(X_test, columns=sample_df.columns).corr() - reconstructed_df.corr()
        resid_array[i] = resid.values

    # Average the correlation residuals over the bootstrap draws
    avg_resid = resid_array.mean(axis=0)
    bootstrapped_resids = pd.DataFrame(avg_resid, columns=sample_df.columns, index=sample_df.columns)
    return bootstrapped_resids
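The bootstrapped residuals shrink with sample size for a simple reason: empirical correlations converge on the correlations of the generating process. A minimal two-feature illustration (hypothetical values, not the job-satisfaction data):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
true_corr = np.array([[1.0, 0.6],
                      [0.6, 1.0]])
chol = np.linalg.cholesky(true_corr)

def sample_corr(n):
    # Draw n correlated pairs and return their empirical correlation matrix
    x = rng.standard_normal((n, 2)) @ chol.T
    return pd.DataFrame(x).corr().values

est = sample_corr(200_000)
# The diagonal is exactly 1 and the off-diagonal error is small at this size
assert np.allclose(np.diag(est), 1.0)
assert abs(est[0, 1] - 0.6) < 0.02
```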

bootstrapped_resids_500 = bootstrap_residuals(vae_fit_500, pd.DataFrame(test_500.dataset, columns=sample_df.columns), sample_df)

bootstrapped_resids_1000 = bootstrap_residuals(vae_fit_1000, pd.DataFrame(test_1000.dataset, columns=sample_df.columns), sample_df)

bootstrapped_resids_10_000 = bootstrap_residuals(vae_fit_10_000, pd.DataFrame(test_10_000.dataset, columns=sample_df.columns), sample_df)


fig, axs = plt.subplots(3, 1, figsize=(10, 20))
axs = axs.flatten()
plot_heatmap(bootstrapped_resids_500, title="""Expected Correlation Residuals for 500 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[0], colorbar=True, vmin=-.25, vmax=.25)

plot_heatmap(bootstrapped_resids_1000, title="""Expected Correlation  Residuals  for 1000 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[1], colorbar=True, vmin=-.25, vmax=.25)

plot_heatmap(bootstrapped_resids_10_000, title="""Expected Correlation  Residuals  for 10,000 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[2], colorbar=True, vmin=-.25, vmax=.25)

Missing Data

sample_df_missing = sample_df.copy()

# Randomly pick 5% of the total elements
mask_remove = np.random.rand(*sample_df_missing.shape) < 0.05

# Set those elements to NaN
sample_df_missing[mask_remove] = np.nan
sample_df_missing.head()
(Output: the first five rows of the sample, with roughly 5% of entries now NaN.)


class MissingDataDataset(Dataset):
    def __init__(self, x, mask):
        # x and mask are tensors of same shape
        self.x = x
        self.mask = mask
        
    def __len__(self):
        return self.x.shape[0]
    
    def __getitem__(self, idx):
        return self.x[idx], self.mask[idx]

def prep_data_vae_missing(sample_size=1000, batch_size=32, missing_frac=0.2):
    sample_df = make_sample(cov_matrix=cov_matrix, size=sample_size, columns=columns)

    total_values = sample_df.size
    num_nans = int(total_values * missing_frac)

    # Choose random flat indices and convert them to (row, col) pairs
    nan_indices = np.random.choice(total_values, num_nans, replace=False)
    rows, cols = np.unravel_index(nan_indices, sample_df.shape)

    # Set the chosen values to NaN (mutate a copy of the underlying array,
    # then rebuild the frame: assigning into .values is not reliable)
    values = sample_df.to_numpy()
    values[rows, cols] = np.nan
    sample_df = pd.DataFrame(values, columns=columns)

    X_train, X_test = train_test_split(sample_df.values, test_size=0.2, random_state=890)

    # Mask: 1=observed, 0=missing
    mask_train = ~pd.DataFrame(X_train).isna()
    mask_test = ~pd.DataFrame(X_test).isna()

    # Tensors (keep NaNs for missing values)
    x_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    mask_train_tensor = torch.tensor(mask_train.values, dtype=torch.float32)

    x_test_tensor = torch.tensor(X_test, dtype=torch.float32)
    mask_test_tensor = torch.tensor(mask_test.values, dtype=torch.float32)

    train_dataset = MissingDataDataset(x_train_tensor, mask_train_tensor)
    test_dataset = MissingDataDataset(x_test_tensor, mask_test_tensor)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader



class NumericVAE_missing(nn.Module):
    def __init__(self, n_features, hidden_dim=64, latent_dim=8):
        super().__init__()
        self.n_features = n_features

        # ---------- Learnable Imputation ----------
        # One learnable parameter per feature for missing values
        self.missing_embeddings = nn.Parameter(torch.zeros(n_features))

        # ---------- ENCODER ----------
        self.fc1_x = nn.Linear(n_features, hidden_dim)

        # Stronger mask encoder: 2-layer MLP
        self.fc1_mask = nn.Sequential(
            nn.Linear(n_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Combine feature and mask embeddings
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # ---------- DECODER ----------
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc_out_mu = nn.Linear(hidden_dim, n_features)
        self.fc_out_logvar = nn.Linear(hidden_dim, n_features)

    def encode(self, x, mask):
        # Impute missing values with learnable parameters
        x_filled = torch.where(
            torch.isnan(x),
            self.missing_embeddings.expand_as(x),
            x
        )

        # Encode features and mask separately
        h_x = F.relu(self.fc1_x(x_filled))
        h_mask = self.fc1_mask(mask)

        # Combine embeddings
        h = h_x + h_mask

        # Latent space
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc2(z))
        recon_mu = self.fc_out_mu(h)
        recon_logvar = self.fc_out_logvar(h)
        return recon_mu, recon_logvar

    def forward(self, x, mask):
        mu, logvar = self.encode(x, mask)
        z = self.reparameterize(mu, logvar)
        recon_mu, recon_logvar = self.decode(z)
        return (recon_mu, recon_logvar), mu, logvar

    def generate(self, n_samples=100):
        self.eval()
        with torch.no_grad():
            z = torch.randn(n_samples, self.fc_mu.out_features)
            recon_mu, recon_logvar = self.decode(z)
            return recon_mu + torch.exp(0.5 * recon_logvar) * torch.randn_like(recon_mu)
def vae_loss_missing(recon_mu, recon_logvar, x, mu, logvar, mask):
    """VAE loss that skips missing values in x for the reconstruction term."""
    # Fill missing values with 0 just for loss computation (masked out below)
    x_filled = torch.where(mask.bool(), x, torch.zeros_like(x))

    recon_var = torch.exp(recon_logvar)
    recon_nll = 0.5 * (torch.log(2 * torch.pi * recon_var) +
                       (x_filled - recon_mu) ** 2 / recon_var)

    # Mask out missing values
    recon_nll = recon_nll * mask
    obs_counts = mask.sum(dim=1).clamp(min=1)
    recon_loss = (recon_nll.sum(dim=1) / obs_counts).mean()

    # KL divergence
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    kl_loss = kl_div.mean()

    return recon_loss, kl_loss
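A useful property of this masked loss is that values at missing positions cannot influence it at all: even NaNs hidden behind the mask leave the loss unchanged. A standalone sketch of the same masking logic:

```python
import torch

def masked_gaussian_nll(recon_mu, recon_logvar, x, mask):
    # Per-sample Gaussian NLL averaged over *observed* features only
    x_filled = torch.where(mask.bool(), x, torch.zeros_like(x))
    var = torch.exp(recon_logvar)
    nll = 0.5 * (torch.log(2 * torch.pi * var) + (x_filled - recon_mu) ** 2 / var)
    nll = nll * mask  # zero out missing features
    return (nll.sum(dim=1) / mask.sum(dim=1).clamp(min=1)).mean()

torch.manual_seed(0)
x = torch.randn(8, 4)
recon_mu = torch.randn(8, 4)
recon_logvar = torch.zeros(8, 4)
mask = torch.ones(8, 4)
mask[:, 0] = 0.0  # pretend the first feature is missing everywhere

base = masked_gaussian_nll(recon_mu, recon_logvar, x, mask)

# Corrupting the masked-out entries (even with NaN) must not change the loss
x_corrupt = x.clone()
x_corrupt[:, 0] = float('nan')
assert torch.isclose(masked_gaussian_nll(recon_mu, recon_logvar, x_corrupt, mask), base)
```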

def beta_annealing(epoch, max_beta=1.0, anneal_epochs=100):
    """Linearly ramp the KL weight from 0 to max_beta over anneal_epochs."""
    return min(max_beta, max_beta * epoch / anneal_epochs)
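The ramp is linear in `epoch` until it saturates at `max_beta`; with `anneal_epochs=10` the schedule is easy to verify by hand (the function is repeated so the snippet runs standalone):

```python
def beta_annealing(epoch, max_beta=1.0, anneal_epochs=100):
    # Ramp the KL weight linearly from 0 to max_beta, then hold it fixed
    return min(max_beta, max_beta * epoch / anneal_epochs)

schedule = [beta_annealing(e, max_beta=1.0, anneal_epochs=10) for e in (0, 5, 10, 50)]
assert schedule == [0.0, 0.5, 1.0, 1.0]
```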
def fit_vae_missing(vae_missing, train_loader, test_loader, optimizer, patience=10, wait=0, n_epochs=1000):
    best_loss = float('inf')
    losses = []

    for epoch in range(n_epochs):
        beta = beta_annealing(epoch, max_beta=1.0, anneal_epochs=10)
        vae_missing.train()
        
        train_loss = 0
        for x_batch, mask_batch in train_loader:
            optimizer.zero_grad()
            (recon_mu, recon_logvar), mu, logvar = vae_missing(x_batch, mask_batch)
            recon_loss, kl_loss = vae_loss_missing(recon_mu, recon_logvar, x_batch, mu, logvar, mask_batch)
            loss = recon_loss + beta * kl_loss
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * x_batch.size(0)

        avg_train_loss = train_loss / len(train_loader.dataset)

        # --- Validation ---
        vae_missing.eval()
        test_loss = 0.0
        with torch.no_grad():
            for x_batch, mask_batch in test_loader:
                (recon_mu, recon_logvar), mu, logvar = vae_missing(x_batch, mask_batch)
                recon_loss, kl_loss = vae_loss_missing(recon_mu, recon_logvar, x_batch, mu, logvar, mask_batch)
                loss = recon_loss + kl_loss
                test_loss += loss.item() * x_batch.size(0)
        avg_test_loss = test_loss / len(test_loader.dataset)

        print(f"Epoch {epoch+1}/{n_epochs} | Train: {avg_train_loss:.4f} | Test: {avg_test_loss:.4f}")

        # Early stopping
        if avg_test_loss < best_loss - 1e-4:
            best_loss, wait = avg_test_loss, 0
            best_state = vae_missing.state_dict()
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                vae_missing.load_state_dict(best_state)  # restore best
                break
        losses.append([avg_train_loss, avg_test_loss, best_loss])
    losses_df = pd.DataFrame(losses, columns=['train_loss', 'test_loss', 'best_loss'])    

    return vae_missing, losses_df


train_loader_5000, test_loader_5000 = prep_data_vae_missing(5000, batch_size=32)
vae_missing_5000 = NumericVAE_missing(n_features=next(iter(train_loader_5000))[0].shape[1], latent_dim=12)
optimizer = optim.Adam(vae_missing_5000.parameters(), lr=1e-4)

fit_vae_missing_5000, losses_df_missing_5000 = fit_vae_missing(vae_missing_5000, train_loader_5000, test_loader_5000, optimizer)



train_loader_25_000, test_loader_25_000 = prep_data_vae_missing(25_000, batch_size=32)
vae_missing_25_000 = NumericVAE_missing(n_features=next(iter(train_loader_25_000))[0].shape[1], latent_dim=12)
optimizer = optim.Adam(vae_missing_25_000.parameters(), lr=1e-4)

fit_vae_missing_25_000, losses_df_missing_25_000 = fit_vae_missing(vae_missing_25_000, train_loader_25_000, test_loader_25_000, optimizer)
bootstrapped_resids_5000 = bootstrap_residuals(fit_vae_missing_5000, pd.DataFrame(test_loader_5000.dataset.x, columns=sample_df.columns), sample_df)

bootstrapped_resids_25_000 = bootstrap_residuals(fit_vae_missing_25_000, pd.DataFrame(test_loader_25_000.dataset.x, columns=sample_df.columns), sample_df)


fig, axs = plt.subplots(2, 1, figsize=(10, 20))
axs = axs.flatten()
plot_heatmap(bootstrapped_resids_5000, title="""Expected Correlation Residuals for 5000 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[0], colorbar=True, vmin=-.25, vmax=.25)

plot_heatmap(bootstrapped_resids_25_000, title="""Expected Correlation  Residuals  for 25,000 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[1], colorbar=True, vmin=-.25, vmax=.25)

Bayesian Inference

def make_pymc_model(sample_df):
    coords = {'features': sample_df.columns,
              'features1': sample_df.columns,
              'obs': range(len(sample_df))}

    with pm.Model(coords=coords) as model:
        # Priors
        mus = pm.Normal("mus", 0, 1, dims='features')
        chol, _, _ = pm.LKJCholeskyCov("chol", n=12, eta=1.0, sd_dist=pm.HalfNormal.dist(1))
        cov = pm.Deterministic('cov', pm.math.dot(chol, chol.T), dims=('features', 'features1'))

        pm.MvNormal('likelihood', mus, cov=cov, observed=sample_df.values, dims=('obs', 'features'))
        
        idata = pm.sample_prior_predictive()
        idata.extend(pm.sample(random_seed=120))
        pm.sample_posterior_predictive(idata, extend_inferencedata=True)

    return idata, model 

idata, model = make_pymc_model(sample_df)
pm.model_to_graphviz(model)

import arviz as az

expected_corr = pd.DataFrame(az.summary(idata, var_names=['chol_corr'])['mean'].values.reshape((12, 12)), columns=sample_df.columns, index=sample_df.columns)

resids = sample_df.corr() - expected_corr
plot_heatmap(resids)

Missing Data

idata_missing, model_missing = make_pymc_model(sample_df_missing)
pm.model_to_graphviz(model_missing)

expected_corr = pd.DataFrame(az.summary(idata_missing, var_names=['chol_corr'])['mean'].values.reshape((12, 12)), columns=sample_df.columns, index=sample_df.columns)

resids = sample_df.corr() - expected_corr
plot_heatmap(resids)

Citation

BibTeX citation:
@online{forde2025,
  author = {Forde, Nathaniel},
  title = {Amortized {Bayesian} {Inference} with {PyTorch}},
  date = {2025-07-25},
  langid = {en},
  abstract = {The cost of generating new sample data can be prohibitive.
    There is a secondary but different cost which attaches to the
    “construction” of novel data. Principal Components Analysis can be
    seen as a technique to optimally reconstruct a complex multivariate
    data set from a lower level compressed dimensional space.
    Variational auto-encoders allow us to achieve yet more flexible
    reconstruction results in non-linear cases. Drawing a new sample
    from the posterior predictive distribution of Bayesian models
    similarly supplies us with insight in the variability of realised
    data. Both methods assume a latent model of the data generating
    process that aims to leverage a compressed representation of the
    data. These are different heuristics with different consequences for
    how we understand the variability in the world. Both encode
    different degrees of inductive bias through architectural
    specification but of the two, Bayesian methods do so transparently.}
}
For attribution, please cite this work as:
Forde, Nathaniel. 2025. “Amortized Bayesian Inference with PyTorch.” July 25, 2025.